/// Log a warning, and if a debug build then panic.
///
/// The arguments are forwarded unchanged to `log::warn!` and, in debug builds only, to `panic!`,
/// so they must be a format string followed by its arguments. Because they are forwarded twice,
/// the argument expressions may be evaluated twice in debug builds.
///
/// The expansion is a single block, so the macro can be used in expression position (for example
/// as a `match` arm) as well as a statement.
macro_rules! debug_panic {
    ($($x:tt)+) => {{
        log::warn!($($x)+);
        #[cfg(debug_assertions)]
        panic!($($x)+);
    }};
}

/// Log a message once at level `lvl_once`, and any later log messages from this line at level
/// `lvl_remaining`.
///
/// A log target is not supported. The string "(LOG_ONCE)" will be prepended to the message to
/// indicate that future messages won't be logged at `lvl_once`.
///
/// Each expansion of this macro has its own "already logged" flag. Two separate invocations are
/// therefore tracked independently, even if their messages are identical. The message must begin
/// with a string-literal format string.
///
/// Each time the statement runs, each level expression is evaluated exactly once. The message
/// arguments are not evaluated unless at least one of the two levels is enabled.
///
/// ```
/// # use log::Level;
/// # use shadow_rs::log_once_at_level;
/// log_once_at_level!(Level::Warn, Level::Debug, "Unexpected flag {}", 10);
/// ```
#[allow(unused_macros)]
#[macro_export]
macro_rules! log_once_at_level {
    ($lvl_once:expr, $lvl_remaining:expr, $str:literal $($x:tt)*) => {{
        // Bind the levels so each level expression is evaluated only once, even though each is
        // used in more than one place below.
        let lvl_once = $lvl_once;
        let lvl_remaining = $lvl_remaining;

        // don't do atomic operations if this log statement isn't enabled
        if log::log_enabled!(lvl_once) || log::log_enabled!(lvl_remaining) {
            static HAS_LOGGED: std::sync::atomic::AtomicBool =
                std::sync::atomic::AtomicBool::new(false);

            // TODO: doing just a `load()` might be faster in the typical case, but would need to
            // have performance metrics to back that up
            // Only the caller that flips the flag from `false` to `true` sees `Ok`.
            let is_first = HAS_LOGGED
                .compare_exchange(
                    false,
                    true,
                    std::sync::atomic::Ordering::Relaxed,
                    std::sync::atomic::Ordering::Relaxed,
                )
                .is_ok();

            let level = if is_first { lvl_once } else { lvl_remaining };
            log::log!(level, "(LOG_ONCE) {}", format_args!($str $($x)*));
        }
    }};
}

/// Log a message once at level `lvl_once` for each distinct value, and any
/// later log messages from this line with an already-logged value at level
/// `lvl_remaining`.
///
/// A log target is not supported. The string "(LOG_ONCE)" will be prepended to
/// the message to indicate that future messages won't be logged at `lvl_once`.
///
/// Each expansion of this macro has its own set of already-logged values of
/// type `$t`. The value expression is only evaluated, and the set only touched,
/// when at least one of the two levels is enabled. Each time the statement runs,
/// each level expression is evaluated exactly once.
///
/// The fast-path (where the given value has already been logged) acquires a
/// read-lock and looks up the value in a hash table.
///
/// ```
/// # use log::Level;
/// # use shadow_rs::log_once_per_value_at_level;
/// # let unknown_flag: i32 = 0;
/// log_once_per_value_at_level!(unknown_flag, i32, Level::Warn, Level::Debug, "Unknown flag value {unknown_flag}");
/// ```
#[allow(unused_macros)]
#[macro_export]
macro_rules! log_once_per_value_at_level {
    ($value:expr, $t:ty, $lvl_once:expr, $lvl_remaining:expr, $str:literal $($x:tt)*) => {{
        // Bind the levels so each level expression is evaluated only once, even though each is
        // used in more than one place below.
        let lvl_once = $lvl_once;
        let lvl_remaining = $lvl_remaining;

        // don't evaluate the value or take the set's lock if this log statement isn't enabled
        if log::log_enabled!(lvl_once) || log::log_enabled!(lvl_remaining) {
            use $crate::utility::once_set::OnceSet;
            static LOGGED_SET: OnceSet<$t> = OnceSet::new();

            // a successful insert means this value hasn't been logged from this call site yet
            let level = if LOGGED_SET.insert($value) {
                lvl_once
            } else {
                lvl_remaining
            };
            log::log!(level, "(LOG_ONCE) {}", format_args!($str $($x)*));
        }
    }};
}

/// Log a message once at warn level, and any later log messages from this line at debug level. A
/// log target is not supported. The string "(LOG_ONCE)" will be prepended to the message to
/// indicate that future messages won't be logged at warn level.
///
/// This is a thin wrapper around `log_once_at_level!`. That macro is resolved through `$crate`,
/// so the expansion does not depend on what is in textual macro scope at the call site. The
/// expansion has no trailing semicolon, so it can also be used in expression position.
///
/// ```ignore
/// warn_once_then_debug!("Unexpected flag {}", 10);
/// ```
#[allow(unused_macros)]
macro_rules! warn_once_then_debug {
    ($($x:tt)+) => {
        $crate::log_once_at_level!(log::Level::Warn, log::Level::Debug, $($x)+)
    };
}

/// Log a message once at warn level, and any later log messages from this line at trace level. A
/// log target is not supported. The string "(LOG_ONCE)" will be prepended to the message to
/// indicate that future messages won't be logged at warn level.
///
/// This is a thin wrapper around `log_once_at_level!`. That macro is resolved through `$crate`,
/// so the expansion does not depend on what is in textual macro scope at the call site. The
/// expansion has no trailing semicolon, so it can also be used in expression position.
///
/// ```ignore
/// warn_once_then_trace!("Unexpected flag {}", 10);
/// ```
#[allow(unused_macros)]
macro_rules! warn_once_then_trace {
    ($($x:tt)+) => {
        $crate::log_once_at_level!(log::Level::Warn, log::Level::Trace, $($x)+)
    };
}

/// Implements logging functions that were generated by the `log_syscall` macro.
///
/// This unit struct holds no data and serves only as a namespace. Each `log_syscall!` invocation
/// adds an associated function named after the syscall (for example `SyscallLogger::close`) by
/// emitting an `impl crate::utility::macros::SyscallLogger { ... }` block.
pub struct SyscallLogger;

/// Creates a logging function. This is written so that the macro can be called from within an
/// `impl` block, ideally directly before the syscall function is defined. See the macro definition
/// for the exact argument types that must be provided to the generated function. The macro itself
/// takes the syscall name, the return type, and the argument types.
///
/// The macro:
///
/// ```ignore
/// log_syscall!(close, /* rv */ c_int, /* fd */ c_int);
/// ```
///
/// expands to something like (excluding some extra boilerplate):
///
/// ```ignore
/// impl SyscallLogger {
///     pub fn close(...) -> std::io::Result<()> { ... }
/// }
/// ```
///
/// This generated function can later be called using:
///
/// ```ignore
/// SyscallLogger::close(...)?;
/// ```
///
/// Macro arguments:
/// - `$name`: the syscall name. It becomes the generated function's name and is also the name
///   written to the log (via `stringify!`).
/// - `$rv`: the type parameter given to `SyscallResultFmt` when formatting the return value.
/// - remaining types: the type parameters given to `SyscallArgsFmt`, one per syscall argument, in
///   order. Zero argument types is allowed.
///
/// Generated function parameters:
/// - `writer`: where the formatted syscall line is written.
/// - `args`: the six raw syscall registers.
/// - `rv`: the syscall's result.
/// - `fmt`: formatting options.
/// - `tid`: the calling thread's id.
/// - `mem`: the memory manager, passed through to the formatters. It is presumably used to
///   dereference pointer arguments; confirm.
///
/// The generated function `unwrap()`s `Worker::current_time()`, so it panics if no current time
/// is available.
macro_rules! log_syscall {
    // No argument types. Forward with an empty argument list: the first `,` is the literal
    // separator and the second is consumed by the next arm's optional trailing `$(,)?`.
    ($name:ident, $rv:ty $(,)?) => {
        log_syscall!($name, $rv,,);
    };
    // Build a per-syscall hidden constant name (e.g. `_syscall_logger_close`) with `paste`, then
    // forward to the implementation arm below.
    ($name:ident, $rv:ty, $($args:ty),* $(,)?) => {
        paste::paste! { log_syscall!([< _syscall_logger_ $name >]; $name, $rv, $($args),*); }
    };
    // Implementation arm. `$const_name` names the dummy constant whose initializer hosts the
    // `impl SyscallLogger` block.
    ($const_name:ident; $name:ident, $rv:ty, $($args:ty),*) => {
        // We use a constant as a hack so that we can do "impl SyscallLogger { ... }" while already
        // inside a "impl SyscallHandler { ... }" block. Apparently they may make this a hard error
        // (with no way to opt-out with an `allow`) in the future:
        // https://github.com/rust-lang/rust/issues/120363
        #[doc(hidden)]
        #[allow(non_upper_case_globals)]
        #[allow(non_local_definitions)]
        const $const_name : () = {
            impl crate::utility::macros::SyscallLogger {
                pub fn $name(
                    writer: impl std::io::Write,
                    args: [shadow_shim_helper_rs::syscall_types::SyscallReg; 6],
                    rv: &crate::host::syscall::types::SyscallResult,
                    fmt: crate::host::syscall::formatter::FmtOptions,
                    tid: crate::host::thread::ThreadId,
                    mem: &crate::host::memory_manager::MemoryManager,
                ) -> std::io::Result<()>
                {
                    // wrap the raw registers and the result in formatters for the given types
                    let syscall_args = <crate::host::syscall::formatter::SyscallArgsFmt::<$($args),*>>::new(args, fmt, mem);
                    // NOTE(review): `rv` is already a `&SyscallResult`, so `&rv` passes a double
                    // reference and relies on deref coercion; plain `rv` is likely what's meant.
                    let syscall_rv = crate::host::syscall::formatter::SyscallResultFmt::<$rv>::new(&rv, args, fmt, mem);

                    crate::host::syscall::formatter::write_syscall(
                        writer,
                        &crate::host::syscall::handler::Worker::current_time().unwrap(),
                        tid,
                        std::stringify!($name),
                        syscall_args,
                        syscall_rv,
                    )
                }
            }
        };
    };
}

/// Projects one to three fields of `$type` out of a byte slice. The fields are returned as
/// mutable `MaybeUninit` references into that slice.
///
/// Forms:
/// - `field_project!(bytes, Type, field)` gives `Option<&mut MaybeUninit<FieldType>>`.
/// - `field_project!(bytes, Type, (f1,))`, `(f1, f2)` or `(f1, f2, f3)` give an `Option` of a
///   tuple of references, in the order the fields were listed.
///
/// `bytes` must be a `&mut [MaybeUninit<u8>]` and is treated as the start of a `$type` value. The
/// returned references borrow from `bytes`. Every field type must implement `shadow_pod::Pod`.
/// Listing fields whose byte ranges overlap, or that start at the same offset, fails to compile.
///
/// Returns `None` if any field is not aligned, or if the bytes slice is too small to contain all
/// fields.
macro_rules! field_project {
    // single field: project it as a 1-tuple, then unwrap the tuple
    ($bytes:expr, $type:ty, $field1:ident) => {
        field_project!($bytes, $type, ($field1,)).map(|x| x.0)
    };
    // one to three fields: give each field a generic type parameter name (A, B, C) and forward to
    // the internal `@` arm
    ($bytes:expr, $type:ty, ($field1:ident,)) => {
        field_project!(@ $bytes, $type, ($field1: A))
    };
    ($bytes:expr, $type:ty, ($field1:ident, $field2:ident)) => {
        field_project!(@ $bytes, $type, ($field1: A), ($field2: B))
    };
    ($bytes:expr, $type:ty, ($field1:ident, $field2:ident, $field3:ident)) => {
        field_project!(@ $bytes, $type, ($field1: A), ($field2: B), ($field3: C))
    };
    // internal arm (marked with `@`): performs the actual projection
    (@ $bytes:expr, $type:ty, $(($field:ident: $generic:ident)),*) => {{
        // perform early type checking; we need `MaybeUninit<u8>` rather than just `u8`, otherwise
        // this macro could be used to write uninitialized padding bytes to a `u8` slice
        let bytes: &mut [std::mem::MaybeUninit<u8>] = $bytes;

        // A pointer to an uninitialized `$type`. It is only used inside `addr_of!` to name field
        // places and obtain their types and sizes; it is never read through.
        const UNINIT: *const $type = std::mem::MaybeUninit::uninit().as_ptr();

        // Size of the type a pointer points to. This lets us get a field's size from an
        // `addr_of!` pointer without having to name the field's type.
        const fn size_of_pointee<T>(_x: *const T) -> usize {
            std::mem::size_of::<T>()
        }

        // This function is needed to:
        // - ensure the type is `Pod`
        // - link the lifetime of `bytes` to the return value's lifetime (we don't want to return a
        //   'static lifetime by accident)
        // - return the correct type for the field, which afaik is only available through the
        //   `addr_of` macro
        fn field_project<$( $generic: shadow_pod::Pod ),*>(
            bytes: &mut [std::mem::MaybeUninit<u8>],
            _for_type_coercion: ($( *const $generic ),*,)
        ) -> Option<($( &mut std::mem::MaybeUninit<$generic> ),*,)> {
            // the byte ranges of each field
            const RANGES: &[std::ops::Range<usize>] = &[ $( {
                const OFFSET: usize = std::mem::offset_of!($type, $field);
                const SIZE: usize = size_of_pointee(unsafe { std::ptr::addr_of!((*UNINIT).$field) });
                OFFSET..(OFFSET+SIZE)
            } ),* ];

            // check that no byte ranges are overlapping
            const {
                let mut i = 0;
                while i < RANGES.len() {
                    let mut j = i+1;
                    while j < RANGES.len() {
                        if RANGES[i].start < RANGES[j].end && RANGES[j].start < RANGES[i].end {
                            panic!("Byte ranges overlap");
                        }
                        j += 1;
                    }
                    i += 1;
                }
            }

            // check that no byte ranges have the same start (don't want two mutable references to
            // the same ZST)
            const {
                let mut i = 0;
                while i < RANGES.len() {
                    let mut j = i+1;
                    while j < RANGES.len() {
                        assert!(RANGES[i].start != RANGES[j].start, "Byte ranges overlap (ZST)");
                        j += 1;
                    }
                    i += 1;
                }
            }

            // get the maximum of all byte ranges
            const RANGE_MAX: usize = {
                let mut max = 0;
                let mut i = 0;
                while i < RANGES.len() {
                    if RANGES[i].end > max {
                        max = RANGES[i].end;
                    }
                    i += 1;
                }
                max
            };

            // make sure a field does not exist outside of `bytes`
            if RANGE_MAX > bytes.len() {
                return None;
            }

            let bytes = bytes.as_mut_ptr();

            // return the references to each field as a tuple
            Some(( $( {
                // NOTE: do not access the original 'bytes' slice within this block, otherwise it
                // causes stacked borrows issues
                const OFFSET: usize = std::mem::offset_of!($type, $field);

                // SAFETY: we've already checked that the field offset is within the bounds of the
                // bytes
                let ptr = unsafe { bytes.add(OFFSET) } as *mut std::mem::MaybeUninit<$generic>;
                // alignment depends on where `bytes` starts in memory, so this is a runtime check
                if !ptr.is_aligned() {
                    return None;
                }
                // SAFETY:
                // - "The pointer must be properly aligned." - checked above
                // - "It must be 'dereferenceable' in the sense defined in the module
                //   documentation." - points to valid memory within a single allocated object, is
                //   non-null
                // - "The pointer must point to an initialized instance of T." - the pointer is a MaybeUninit
                // - "You must enforce Rust’s aliasing rules, since the returned lifetime 'a is
                //   arbitrarily chosen and does not necessarily reflect the actual lifetime of the
                //   data. In particular, while this reference exists, the memory the pointer points
                //   to must not get accessed (read or written) through any other pointer." - the
                //   outer function makes sure that the returned reference has the correct lifetime
                unsafe { ptr.as_mut() }.unwrap()
            } ),*, ))
        }

        // there's no way to find the types of the fields directly, so we need to get values whose
        // types contain the types of the fields and let rust use type inference to cast to the
        // correct types
        let addr_of_fields = ($( const { unsafe { std::ptr::addr_of!((*UNINIT).$field) } } ),*,);
        field_project(bytes, addr_of_fields)
    }};
}

#[cfg(test)]
mod tests {
    // will panic in debug mode
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic]
    fn debug_panic_macro() {
        debug_panic!("Hello {}", "World");
    }

    // will *not* panic in release mode
    #[test]
    #[cfg(not(debug_assertions))]
    fn debug_panic_macro() {
        debug_panic!("Hello {}", "World");
    }

    #[test]
    fn log_once_at_level() {
        // we don't have a logger set up so we can't actually inspect the log output (well we
        // probably could with a custom logger), so instead we just make sure it compiles
        for x in 0..10 {
            log_once_at_level!(log::Level::Warn, log::Level::Debug, "{x}");
        }

        log_once_at_level!(log::Level::Warn, log::Level::Debug, "A");
        log_once_at_level!(log::Level::Warn, log::Level::Debug, "A");

        // expected log output (if a logger were installed) is:
        // Warn: (LOG_ONCE) 0
        // Debug: (LOG_ONCE) 1
        // Debug: (LOG_ONCE) 2
        // ...
        // Warn: (LOG_ONCE) A
        // Warn: (LOG_ONCE) A
        //
        // (the two "A" statements are separate call sites, so each has its own "logged" flag)
    }

    #[test]
    fn log_once_per_value_at_level() {
        // as above, there's no logger installed, so this only checks that the macro compiles
        for x in 0..10 {
            log_once_per_value_at_level!(x % 3, i32, log::Level::Warn, log::Level::Debug, "{x}");
        }

        // expected log output (if a logger were installed) is:
        // Warn: (LOG_ONCE) 0
        // Warn: (LOG_ONCE) 1
        // Warn: (LOG_ONCE) 2
        // Debug: (LOG_ONCE) 3
        // Debug: (LOG_ONCE) 4
        // ...
    }

    #[test]
    fn warn_once() {
        warn_once_then_trace!("A");
        warn_once_then_debug!("A");
    }

    #[test]
    fn field_project_1() {
        let mut foo: libc::nlmsghdr = shadow_pod::zeroed();
        let foo_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut foo) };

        let foo_nlmsg_type = field_project!(foo_bytes, libc::nlmsghdr, nlmsg_type).unwrap();

        foo_nlmsg_type.write(10);

        assert_eq!(foo.nlmsg_type, 10);
    }

    #[test]
    fn field_project_2() {
        let mut foo: libc::nlmsghdr = shadow_pod::zeroed();
        let foo_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut foo) };

        let (foo_nlmsg_type, foo_nlmsg_flags) =
            field_project!(foo_bytes, libc::nlmsghdr, (nlmsg_type, nlmsg_flags)).unwrap();

        foo_nlmsg_type.write(10);
        foo_nlmsg_flags.write(20);

        // make sure the order we access the fields doesn't matter (no stacked borrows miri errors)
        foo_nlmsg_flags.write(40);
        foo_nlmsg_type.write(30);

        assert_eq!(foo.nlmsg_type, 30);
        assert_eq!(foo.nlmsg_flags, 40);
    }

    #[test]
    fn field_project_3() {
        let mut foo: libc::nlmsghdr = shadow_pod::zeroed();
        let foo_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut foo) };

        // exercise the three-field form, with fields that aren't adjacent in the struct
        let (foo_nlmsg_len, foo_nlmsg_seq, foo_nlmsg_pid) =
            field_project!(foo_bytes, libc::nlmsghdr, (nlmsg_len, nlmsg_seq, nlmsg_pid)).unwrap();

        foo_nlmsg_pid.write(3);
        foo_nlmsg_len.write(1);
        foo_nlmsg_seq.write(2);

        assert_eq!(foo.nlmsg_len, 1);
        assert_eq!(foo.nlmsg_seq, 2);
        assert_eq!(foo.nlmsg_pid, 3);
    }

    #[test]
    fn field_project_type_inference() {
        let mut foo: libc::nlmsghdr = shadow_pod::zeroed();
        let foo_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut foo) };

        // make sure field_project returns a u16 reference (ideally we'd want a test that uses an
        // incorrect type and makes sure that the code fails to build to make sure that rust's type
        // inference isn't leading to incorrect code, but writing rust tests that check that code
        // fails to compile isn't supported and the workarounds aren't very nice)
        let _nlmsg_type: &mut std::mem::MaybeUninit<u16> =
            field_project!(foo_bytes, libc::nlmsghdr, nlmsg_type).unwrap();
    }

    #[test]
    fn field_project_range() {
        let mut foo: libc::nlmsghdr = shadow_pod::zeroed();
        let foo_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut foo) };

        // #[repr(C)]
        // pub struct nlmsghdr {
        //     pub nlmsg_len: u32,
        //     pub nlmsg_type: u16,
        //     ...
        assert!(field_project!(&mut foo_bytes[..0], libc::nlmsghdr, nlmsg_type).is_none());
        assert!(field_project!(&mut foo_bytes[..5], libc::nlmsghdr, nlmsg_type).is_none());
        assert!(field_project!(&mut foo_bytes[..6], libc::nlmsghdr, nlmsg_type).is_some());
    }

    #[test]
    fn field_project_align() {
        let mut foo: libc::nlmsghdr = shadow_pod::zeroed();
        let foo_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut foo) };

        // #[repr(C)]
        // pub struct nlmsghdr {
        //     pub nlmsg_len: u32,
        //     pub nlmsg_type: u16,
        //     ...
        assert!(field_project!(&mut foo_bytes[..], libc::nlmsghdr, nlmsg_type).is_some());
        assert!(field_project!(&mut foo_bytes[1..], libc::nlmsghdr, nlmsg_type).is_none());
    }
}
